Supplementary Appendix

This is a supplementary appendix to the research paper Endogenous Macrodynamics in Algorithmic Recourse. It contains all of the experimental results, including those not highlighted in the actual paper. It also contains additional information about the proposed counterfactual generators.

Experimental Results

Synthetic Data

This notebook was used to run the experiments for the synthetic datasets. In the following, we first run the experiments and then generate visualizations and tables.

Experiments

Code
using Pkg; Pkg.activate("dev")
Code
include("dev/utils.jl")
using AlgorithmicRecourseDynamics
using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra, Images
# Use the `:wong` plot theme for all figures produced below.
theme(:wong)
# Locations for experiment outputs and for saved figures of the synthetic-data study.
# NOTE(review): `output_dir` / `www_dir` are presumably defined in dev/utils.jl — confirm.
output_path = output_dir("synthetic")
www_path = www_dir("synthetic");
Code
# Value handed to the synthetic data loader
# (presumably the maximum number of observations per data set — confirm).
max_obs = 1000
catalogue = AlgorithmicRecourseDynamics.Data.load_synthetic(max_obs)

# Names of the synthetic data sets used in this notebook.
choices = [:linearly_separable, :overlapping, :circles, :moons]

# Keep only the catalogue entries whose first element is one of the chosen names.
data_sets = filter(entry -> first(entry) in choices, catalogue)
Code
# Model identifiers passed to `set_up_experiments`.
models = [:LogisticRegression, :FluxModel, :FluxEnsemble]

# Counterfactual generators; the keys are reused later as labels in figure file names.
generators = Dict(
    :Greedy => GreedyGenerator(),
    :Generic => GenericGenerator(),
    :REVISE => REVISEGenerator(),
    :DICE => DiCEGenerator(),
)
Code
# Combine the selected data sets, model identifiers and generators into experiments.
experiments = set_up_experiments(data_sets, models, generators)
Code
using AlgorithmicRecourseDynamics.Models: model_evaluation

# Build one panel per (experiment, model) pair: the model plotted over the test data,
# with the misclassified test points overlaid in red. The panel title carries the
# experiment name, the model name and the rounded evaluation score.
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.test_data); digits=2)
        plt = plot(M, exp_.test_data; title="$exp_name;\n $M_name ($score)")
        # Errors: columns whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        # Append in place instead of rebinding the global `plts`, so the loop does not
        # depend on interactive soft-scope rules and does not re-splat the whole vector.
        push!(plts, plt)
    end
end

# Arrange all panels in a grid with one row per data set and one column per model.
# NOTE(review): Plots' `size` is (width, height) while `layout` is (rows, columns);
# here the width scales with the number of rows — confirm this is intended.
plt = plot(
    plts...;
    layout=(length(choices), length(models)),
    size=(length(choices) * 300, length(models) * 300),
)
savefig(plt, joinpath(www_path, "models_test_before.png"))

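Figure 1 shows each model plotted over the test data before running the experiment. Misclassified points are marked in red, and each panel title reports the model's evaluation score.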

Code
# Read the saved figure back from disk so that it is shown as the cell output.
joinpath(www_path, "models_test_before.png") |> load

Code
using AlgorithmicRecourseDynamics.Models: model_evaluation

# Build one panel per (experiment, model) pair: the model plotted over the training data,
# with the misclassified training points overlaid in red. The panel title carries the
# experiment name, the model name and the rounded evaluation score.
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.train_data); digits=2)
        plt = plot(M, exp_.train_data; title="$exp_name;\n $M_name ($score)")
        # Errors: columns whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.train_data.X)) .!= exp_.train_data.y))
        x_wrongly_labelled = exp_.train_data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        # Append in place instead of rebinding the global `plts`, so the loop does not
        # depend on interactive soft-scope rules and does not re-splat the whole vector.
        push!(plts, plt)
    end
end

# Arrange all panels in a grid with one row per data set and one column per model.
# NOTE(review): Plots' `size` is (width, height) while `layout` is (rows, columns);
# here the width scales with the number of rows — confirm this is intended.
plt = plot(
    plts...;
    layout=(length(choices), length(models)),
    size=(length(choices) * 300, length(models) * 300),
)
savefig(plt, joinpath(www_path, "models_train_before.png"))

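Figure 2 shows each model plotted over the training data before running the experiment. Misclassified points are marked in red, and each panel title reports the model's evaluation score.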

Code
# Read the saved figure back from disk so that it is shown as the cell output.
joinpath(www_path, "models_train_before.png") |> load

Code
using Serialization

# Experiment settings forwarded to `run_experiments`.
n_evals = 5
n_rounds = 50
# Evaluation interval chosen so that `n_evals` evaluations fit into `n_rounds` rounds.
evaluate_every = round(Int, n_rounds / n_evals)
n_folds = 5
n_bootstrap = 1
T = 100

results = run_experiments(
    experiments;
    save_path=output_path,
    evaluate_every=evaluate_every,
    n_rounds=n_rounds,
    n_folds=n_folds,
    n_bootstrap=n_bootstrap,
    T=T,
)

# Persist the results so that the plotting cells can reload them without re-running.
Serialization.serialize(joinpath(output_path, "results.jls"), results)
Code
using AlgorithmicRecourseDynamics.Models: model_evaluation

# For every experiment and generator, collect one panel per recourse system of the chosen
# fold: the system's model plotted over the test data, misclassified points in red.
plot_dict = Dict(key => Dict() for (key, val) in results)
fold = 1
for (name, res) in results
    exp_ = res.experiment
    plot_dict[name] = Dict(key => [] for (key, val) in exp_.generators)
    rec_sys = exp_.recourse_systems[fold]
    sys_ids = collect(exp_.system_identifiers)
    # Keep the system count and the per-system model in separate variables; reusing one
    # name for both would change its type inside the loop it bounds.
    n_systems = length(rec_sys)
    for m in 1:n_systems
        model_name, generator_name = sys_ids[m]
        mdl = rec_sys[m].model
        score = round(model_evaluation(mdl, exp_.test_data); digits=2)
        plt = plot(mdl, exp_.test_data; title="$name;\n $model_name ($score)")
        # Errors: columns whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(mdl, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        # Append in place rather than rebuilding the vector with `vcat` on every panel.
        push!(plot_dict[name][generator_name], plt)
    end
end

# Regroup by generator: concatenate the panels of all experiments for each generator key.
plot_dict = Dict(
    key => reduce(vcat, [plots[key] for plots in values(plot_dict)])
    for (key, value) in generators
)

# Save one grid figure per generator.
# NOTE(review): Plots' `size` is (width, height) while `layout` is (rows, columns);
# here the width scales with the number of rows — confirm this is intended.
for (name, plts) in plot_dict
    plt = plot(
        plts...;
        layout=(length(choices), length(models)),
        size=(length(choices) * 300, length(models) * 300),
    )
    savefig(plt, joinpath(www_path, "models_test_after_$(name).png"))
end

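Figure 3 shows, separately for each generator, the models plotted over the test data after running the experiment. Misclassified points are marked in red, and each panel title reports the model's evaluation score.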

Code
# Show every saved figure in `www_path` whose file name contains "models_test_after".
img_files = joinpath.(www_path, filter(contains("models_test_after"), readdir(www_path)))
foreach(img -> display(load(img)), img_files)

(a) DICE

(b) Generic

(c) Greedy

(d) REVISE (Latent)

Code
using AlgorithmicRecourseDynamics.Models: model_evaluation

# For every experiment and generator, collect one panel per recourse system of the chosen
# fold: the system's model plotted over the system's own data, misclassified points in red.
plot_dict = Dict(key => Dict() for (key, val) in results)
fold = 1
for (name, res) in results
    exp_ = res.experiment
    plot_dict[name] = Dict(key => [] for (key, val) in exp_.generators)
    rec_sys = exp_.recourse_systems[fold]
    sys_ids = collect(exp_.system_identifiers)
    # Keep the system count and the per-system model in separate variables; reusing one
    # name for both would change its type inside the loop it bounds.
    n_systems = length(rec_sys)
    for m in 1:n_systems
        model_name, generator_name = sys_ids[m]
        mdl = rec_sys[m].model
        data = rec_sys[m].data
        score = round(model_evaluation(mdl, data); digits=2)
        plt = plot(mdl, data; title="$name;\n $model_name ($score)")
        # Errors: columns whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(mdl, data.X)) .!= data.y))
        x_wrongly_labelled = data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        # Append in place rather than rebuilding the vector with `vcat` on every panel.
        push!(plot_dict[name][generator_name], plt)
    end
end

# Regroup by generator: concatenate the panels of all experiments for each generator key.
plot_dict = Dict(
    key => reduce(vcat, [plots[key] for plots in values(plot_dict)])
    for (key, value) in generators
)

# Save one grid figure per generator.
# NOTE(review): Plots' `size` is (width, height) while `layout` is (rows, columns);
# here the width scales with the number of rows — confirm this is intended.
for (name, plts) in plot_dict
    plt = plot(
        plts...;
        layout=(length(choices), length(models)),
        size=(length(choices) * 300, length(models) * 300),
    )
    savefig(plt, joinpath(www_path, "models_train_after_$(name).png"))
end

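Figure 4 shows, separately for each generator, the models plotted over each recourse system's training data after running the experiment. Misclassified points are marked in red, and each panel title reports the model's evaluation score.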

Code
# Show every saved figure in `www_path` whose file name contains "models_train_after".
img_files = joinpath.(www_path, filter(contains("models_train_after"), readdir(www_path)))
foreach(img -> display(load(img)), img_files)

(a) DICE

(b) Generic

(c) Greedy

(d) REVISE (Latent)

Plots

Code
using Serialization
# Reload the serialized experiment results written by the experiment cell, so the plots
# and tables below can be produced without re-running the experiments.
results = Serialization.deserialize(joinpath(output_path,"results.jls"))
Dict{Symbol, ExperimentResults} with 4 entries:
  :overlapping        => ExperimentResults(2520×13 DataFrame…
  :linearly_separable => ExperimentResults(2520×13 DataFrame…
  :circles            => ExperimentResults(2520×13 DataFrame…
  :moons              => ExperimentResults(2520×13 DataFrame
Code
using Images
# Create, save and keep two charts for every data set in `results`.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Chart built from the full result object; saved as "line_chart_<data set>.png".
    plt = plot(res)
    line_chart_file = joinpath(www_path, "line_chart_$(data_name).png")
    Images.save(line_chart_file, plt)
    line_charts[data_name] = plt

    # Chart for the largest value of `n` found in the results; saved as
    # "errorbar_chart_<data set>.png".
    last_n = maximum(res.output.n)
    plt = plot(res, last_n)
    errorbar_chart_file = joinpath(www_path, "errorbar_chart_$(data_name).png")
    Images.save(errorbar_chart_file, plt)
    errorbar_charts[data_name] = plt
end

Line Charts

Figure 5 shows the evolution of the evaluation metrics over the course of the experiment.

Code
# Show every saved figure in `www_path` whose file name contains "line_chart".
img_files = joinpath.(www_path, filter(contains("line_chart"), readdir(www_path)))
foreach(img -> display(load(img)), img_files)

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Error Bar Charts

Figure 6 shows the evaluation metrics at the end of the experiments.

Code
# Show every saved figure in `www_path` whose file name contains "errorbar_chart".
img_files = joinpath.(www_path, filter(contains("errorbar_chart"), readdir(www_path)))
foreach(img -> display(load(img)), img_files)

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Tables

Table 1 shows a summary of all results.

Code
using AlgorithmicRecourseDynamics: kable
# Render the results as an HTML table.
# NOTE(review): `[50]` presumably selects the evaluation at n = 50 (the final round,
# matching the `n` column of the table) — confirm against `kable`'s signature.
kable(results, [50]; format="html")

Table 1: Summary of all results for the synthetic datasets.

<table>
 <thead>
  <tr>
   <th style="text-align:center;"> dataset </th>
   <th style="text-align:center;"> n </th>
   <th style="text-align:center;"> model </th>
   <th style="text-align:center;"> generator </th>
   <th style="text-align:center;"> decisiveness </th>
   <th style="text-align:center;"> disagreement </th>
   <th style="text-align:center;"> mmd_domain </th>
   <th style="text-align:center;"> mmd_grid </th>
   <th style="text-align:center;"> mmd_model </th>
   <th style="text-align:center;"> model_performance </th>
   <th style="text-align:center;"> perturbation </th>
  </tr>
 </thead>
<tbody>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 1.169 (0.090) </td>
   <td style="text-align:center;"> 0.142 (0.006) </td>
   <td style="text-align:center;"> 0.118 (0.005) </td>
   <td style="text-align:center;"> 0.019 (0.004) </td>
   <td style="text-align:center;"> 0.045 (0.003) </td>
   <td style="text-align:center;"> -0.073 (0.007) </td>
   <td style="text-align:center;"> 0.365 (0.076) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 1.171 (0.091) </td>
   <td style="text-align:center;"> 0.145 (0.008) </td>
   <td style="text-align:center;"> 0.117 (0.004) </td>
   <td style="text-align:center;"> 0.025 (0.005) </td>
   <td style="text-align:center;"> 0.038 (0.006) </td>
   <td style="text-align:center;"> -0.076 (0.005) </td>
   <td style="text-align:center;"> 0.370 (0.088) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 1.048 (0.144) </td>
   <td style="text-align:center;"> 0.162 (0.010) </td>
   <td style="text-align:center;"> 0.097 (0.001) </td>
   <td style="text-align:center;"> 0.037 (0.006) </td>
   <td style="text-align:center;"> 0.051 (0.011) </td>
   <td style="text-align:center;"> -0.092 (0.008) </td>
   <td style="text-align:center;"> 0.325 (0.051) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.812 (0.127) </td>
   <td style="text-align:center;"> 0.128 (0.012) </td>
   <td style="text-align:center;"> 0.116 (0.007) </td>
   <td style="text-align:center;"> 0.015 (0.001) </td>
   <td style="text-align:center;"> 0.028 (0.005) </td>
   <td style="text-align:center;"> -0.059 (0.010) </td>
   <td style="text-align:center;"> 0.252 (0.040) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 1.060 (0.119) </td>
   <td style="text-align:center;"> 0.153 (0.012) </td>
   <td style="text-align:center;"> 0.113 (0.006) </td>
   <td style="text-align:center;"> 0.029 (0.006) </td>
   <td style="text-align:center;"> 0.046 (0.007) </td>
   <td style="text-align:center;"> -0.074 (0.007) </td>
   <td style="text-align:center;"> 0.411 (0.088) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 1.120 (0.107) </td>
   <td style="text-align:center;"> 0.150 (0.007) </td>
   <td style="text-align:center;"> 0.114 (0.007) </td>
   <td style="text-align:center;"> 0.029 (0.008) </td>
   <td style="text-align:center;"> 0.044 (0.003) </td>
   <td style="text-align:center;"> -0.070 (0.006) </td>
   <td style="text-align:center;"> 0.401 (0.111) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 0.973 (0.220) </td>
   <td style="text-align:center;"> 0.164 (0.020) </td>
   <td style="text-align:center;"> 0.096 (0.002) </td>
   <td style="text-align:center;"> 0.039 (0.008) </td>
   <td style="text-align:center;"> 0.049 (0.011) </td>
   <td style="text-align:center;"> -0.088 (0.019) </td>
   <td style="text-align:center;"> 0.342 (0.051) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.796 (0.200) </td>
   <td style="text-align:center;"> 0.133 (0.012) </td>
   <td style="text-align:center;"> 0.116 (0.005) </td>
   <td style="text-align:center;"> 0.017 (0.001) </td>
   <td style="text-align:center;"> 0.031 (0.007) </td>
   <td style="text-align:center;"> -0.050 (0.010) </td>
   <td style="text-align:center;"> 0.250 (0.050) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 0.338 (0.262) </td>
   <td style="text-align:center;"> 0.236 (0.015) </td>
   <td style="text-align:center;"> 0.120 (0.001) </td>
   <td style="text-align:center;"> 0.061 (0.009) </td>
   <td style="text-align:center;"> 0.083 (0.024) </td>
   <td style="text-align:center;"> -0.222 (0.023) </td>
   <td style="text-align:center;"> 1.995 (0.129) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 0.332 (0.275) </td>
   <td style="text-align:center;"> 0.238 (0.017) </td>
   <td style="text-align:center;"> 0.120 (0.001) </td>
   <td style="text-align:center;"> 0.064 (0.006) </td>
   <td style="text-align:center;"> 0.094 (0.010) </td>
   <td style="text-align:center;"> -0.224 (0.029) </td>
   <td style="text-align:center;"> 2.011 (0.105) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 0.419 (0.212) </td>
   <td style="text-align:center;"> 0.204 (0.013) </td>
   <td style="text-align:center;"> 0.109 (0.001) </td>
   <td style="text-align:center;"> 0.046 (0.003) </td>
   <td style="text-align:center;"> 0.068 (0.011) </td>
   <td style="text-align:center;"> -0.161 (0.017) </td>
   <td style="text-align:center;"> 1.532 (0.104) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> overlapping </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.717 (0.198) </td>
   <td style="text-align:center;"> 0.121 (0.006) </td>
   <td style="text-align:center;"> 0.120 (0.001) </td>
   <td style="text-align:center;"> 0.026 (0.002) </td>
   <td style="text-align:center;"> 0.030 (0.008) </td>
   <td style="text-align:center;"> -0.082 (0.007) </td>
   <td style="text-align:center;"> 1.305 (0.311) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 0.158 (0.031) </td>
   <td style="text-align:center;"> 0.005 (0.006) </td>
   <td style="text-align:center;"> 0.042 (0.004) </td>
   <td style="text-align:center;"> 0.014 (0.002) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> -0.008 (0.006) </td>
   <td style="text-align:center;"> 0.426 (0.102) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 0.146 (0.008) </td>
   <td style="text-align:center;"> 0.010 (0.005) </td>
   <td style="text-align:center;"> 0.039 (0.006) </td>
   <td style="text-align:center;"> 0.019 (0.017) </td>
   <td style="text-align:center;"> 0.000 (0.001) </td>
   <td style="text-align:center;"> -0.012 (0.006) </td>
   <td style="text-align:center;"> 0.414 (0.079) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 0.243 (0.080) </td>
   <td style="text-align:center;"> 0.038 (0.010) </td>
   <td style="text-align:center;"> 0.132 (0.002) </td>
   <td style="text-align:center;"> 0.092 (0.010) </td>
   <td style="text-align:center;"> 0.003 (0.002) </td>
   <td style="text-align:center;"> -0.063 (0.011) </td>
   <td style="text-align:center;"> 0.246 (0.036) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.127 (0.031) </td>
   <td style="text-align:center;"> 0.014 (0.005) </td>
   <td style="text-align:center;"> 0.117 (0.004) </td>
   <td style="text-align:center;"> 0.016 (0.004) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> -0.014 (0.004) </td>
   <td style="text-align:center;"> 0.383 (0.025) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 0.050 (0.037) </td>
   <td style="text-align:center;"> 0.010 (0.004) </td>
   <td style="text-align:center;"> 0.034 (0.004) </td>
   <td style="text-align:center;"> 0.015 (0.011) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> -0.017 (0.010) </td>
   <td style="text-align:center;"> 0.353 (0.103) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 0.027 (0.023) </td>
   <td style="text-align:center;"> 0.006 (0.004) </td>
   <td style="text-align:center;"> 0.032 (0.005) </td>
   <td style="text-align:center;"> 0.009 (0.006) </td>
   <td style="text-align:center;"> -0.000 (0.001) </td>
   <td style="text-align:center;"> -0.008 (0.003) </td>
   <td style="text-align:center;"> 0.392 (0.101) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 0.337 (0.070) </td>
   <td style="text-align:center;"> 0.033 (0.014) </td>
   <td style="text-align:center;"> 0.136 (0.004) </td>
   <td style="text-align:center;"> 0.085 (0.015) </td>
   <td style="text-align:center;"> 0.003 (0.004) </td>
   <td style="text-align:center;"> -0.057 (0.020) </td>
   <td style="text-align:center;"> 0.228 (0.030) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.041 (0.019) </td>
   <td style="text-align:center;"> 0.014 (0.005) </td>
   <td style="text-align:center;"> 0.112 (0.003) </td>
   <td style="text-align:center;"> 0.009 (0.006) </td>
   <td style="text-align:center;"> -0.000 (0.001) </td>
   <td style="text-align:center;"> -0.014 (0.002) </td>
   <td style="text-align:center;"> 0.368 (0.032) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 0.596 (0.065) </td>
   <td style="text-align:center;"> 0.173 (0.009) </td>
   <td style="text-align:center;"> 0.105 (0.001) </td>
   <td style="text-align:center;"> 0.055 (0.017) </td>
   <td style="text-align:center;"> 0.056 (0.018) </td>
   <td style="text-align:center;"> -0.092 (0.013) </td>
   <td style="text-align:center;"> 3.919 (0.287) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 0.594 (0.054) </td>
   <td style="text-align:center;"> 0.173 (0.007) </td>
   <td style="text-align:center;"> 0.105 (0.001) </td>
   <td style="text-align:center;"> 0.054 (0.012) </td>
   <td style="text-align:center;"> 0.050 (0.012) </td>
   <td style="text-align:center;"> -0.096 (0.008) </td>
   <td style="text-align:center;"> 3.912 (0.325) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 0.542 (0.082) </td>
   <td style="text-align:center;"> 0.176 (0.013) </td>
   <td style="text-align:center;"> 0.101 (0.001) </td>
   <td style="text-align:center;"> 0.057 (0.016) </td>
   <td style="text-align:center;"> 0.053 (0.011) </td>
   <td style="text-align:center;"> -0.095 (0.017) </td>
   <td style="text-align:center;"> 3.760 (0.261) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> moons </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.687 (0.066) </td>
   <td style="text-align:center;"> 0.124 (0.006) </td>
   <td style="text-align:center;"> 0.140 (0.002) </td>
   <td style="text-align:center;"> 0.037 (0.011) </td>
   <td style="text-align:center;"> 0.031 (0.005) </td>
   <td style="text-align:center;"> -0.031 (0.008) </td>
   <td style="text-align:center;"> 3.503 (0.292) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.031 (0.017) </td>
   <td style="text-align:center;"> 0.020 (0.009) </td>
   <td style="text-align:center;"> -0.000 (0.001) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.076 (0.024) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.036 (0.013) </td>
   <td style="text-align:center;"> 0.021 (0.018) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.083 (0.028) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 0.051 (0.036) </td>
   <td style="text-align:center;"> 0.010 (0.009) </td>
   <td style="text-align:center;"> 0.337 (0.001) </td>
   <td style="text-align:center;"> 0.023 (0.013) </td>
   <td style="text-align:center;"> -0.000 (0.001) </td>
   <td style="text-align:center;"> -0.003 (0.004) </td>
   <td style="text-align:center;"> 0.351 (0.064) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.129 (0.001) </td>
   <td style="text-align:center;"> 0.028 (0.008) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.072 (0.004) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.028 (0.012) </td>
   <td style="text-align:center;"> 0.021 (0.015) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.073 (0.024) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.019 (0.008) </td>
   <td style="text-align:center;"> 0.023 (0.006) </td>
   <td style="text-align:center;"> -0.000 (0.001) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.078 (0.006) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 0.066 (0.065) </td>
   <td style="text-align:center;"> 0.011 (0.009) </td>
   <td style="text-align:center;"> 0.337 (0.002) </td>
   <td style="text-align:center;"> 0.019 (0.013) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> -0.004 (0.005) </td>
   <td style="text-align:center;"> 0.343 (0.051) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.129 (0.001) </td>
   <td style="text-align:center;"> 0.027 (0.004) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.074 (0.002) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 0.041 (0.027) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.127 (0.010) </td>
   <td style="text-align:center;"> 0.083 (0.007) </td>
   <td style="text-align:center;"> -0.000 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 9.771 (0.638) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 0.051 (0.048) </td>
   <td style="text-align:center;"> 0.001 (0.001) </td>
   <td style="text-align:center;"> 0.133 (0.014) </td>
   <td style="text-align:center;"> 0.094 (0.011) </td>
   <td style="text-align:center;"> -0.000 (0.001) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 9.420 (0.803) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 1.372 (0.180) </td>
   <td style="text-align:center;"> 0.057 (0.017) </td>
   <td style="text-align:center;"> 0.333 (0.001) </td>
   <td style="text-align:center;"> 0.066 (0.005) </td>
   <td style="text-align:center;"> 0.015 (0.007) </td>
   <td style="text-align:center;"> -0.043 (0.017) </td>
   <td style="text-align:center;"> 11.720 (0.702) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> linearly_separable </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.002 (0.002) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 0.128 (0.001) </td>
   <td style="text-align:center;"> 0.031 (0.017) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> 0.000 (0.000) </td>
   <td style="text-align:center;"> 5.783 (4.842) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 0.039 (0.026) </td>
   <td style="text-align:center;"> 0.006 (0.003) </td>
   <td style="text-align:center;"> 0.008 (0.002) </td>
   <td style="text-align:center;"> 0.002 (0.000) </td>
   <td style="text-align:center;"> -0.000 (0.000) </td>
   <td style="text-align:center;"> -0.006 (0.002) </td>
   <td style="text-align:center;"> 0.390 (0.087) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 0.046 (0.029) </td>
   <td style="text-align:center;"> 0.009 (0.005) </td>
   <td style="text-align:center;"> 0.008 (0.002) </td>
   <td style="text-align:center;"> 0.002 (0.001) </td>
   <td style="text-align:center;"> -0.000 (0.001) </td>
   <td style="text-align:center;"> -0.008 (0.008) </td>
   <td style="text-align:center;"> 0.382 (0.077) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 0.196 (0.186) </td>
   <td style="text-align:center;"> 0.040 (0.016) </td>
   <td style="text-align:center;"> 0.050 (0.001) </td>
   <td style="text-align:center;"> 0.007 (0.001) </td>
   <td style="text-align:center;"> 0.003 (0.005) </td>
   <td style="text-align:center;"> -0.043 (0.013) </td>
   <td style="text-align:center;"> 0.332 (0.064) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxEnsemble </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.042 (0.015) </td>
   <td style="text-align:center;"> 0.003 (0.003) </td>
   <td style="text-align:center;"> -0.003 (0.000) </td>
   <td style="text-align:center;"> 0.001 (0.000) </td>
   <td style="text-align:center;"> -0.000 (0.001) </td>
   <td style="text-align:center;"> -0.003 (0.003) </td>
   <td style="text-align:center;"> 0.363 (0.075) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 0.030 (0.010) </td>
   <td style="text-align:center;"> 0.010 (0.006) </td>
   <td style="text-align:center;"> 0.007 (0.002) </td>
   <td style="text-align:center;"> 0.003 (0.001) </td>
   <td style="text-align:center;"> -0.000 (0.001) </td>
   <td style="text-align:center;"> -0.010 (0.010) </td>
   <td style="text-align:center;"> 0.350 (0.065) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 0.031 (0.021) </td>
   <td style="text-align:center;"> 0.010 (0.006) </td>
   <td style="text-align:center;"> 0.008 (0.000) </td>
   <td style="text-align:center;"> 0.003 (0.001) </td>
   <td style="text-align:center;"> 0.001 (0.001) </td>
   <td style="text-align:center;"> -0.011 (0.008) </td>
   <td style="text-align:center;"> 0.361 (0.079) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 0.200 (0.177) </td>
   <td style="text-align:center;"> 0.037 (0.016) </td>
   <td style="text-align:center;"> 0.053 (0.003) </td>
   <td style="text-align:center;"> 0.008 (0.002) </td>
   <td style="text-align:center;"> 0.003 (0.006) </td>
   <td style="text-align:center;"> -0.051 (0.016) </td>
   <td style="text-align:center;"> 0.301 (0.064) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> FluxModel </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 0.049 (0.019) </td>
   <td style="text-align:center;"> 0.003 (0.004) </td>
   <td style="text-align:center;"> -0.003 (0.000) </td>
   <td style="text-align:center;"> 0.001 (0.000) </td>
   <td style="text-align:center;"> -0.001 (0.000) </td>
   <td style="text-align:center;"> -0.003 (0.003) </td>
   <td style="text-align:center;"> 0.366 (0.075) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> DICE </td>
   <td style="text-align:center;"> 17.221 (0.090) </td>
   <td style="text-align:center;"> 0.425 (0.000) </td>
   <td style="text-align:center;"> 0.158 (0.001) </td>
   <td style="text-align:center;"> 0.670 (0.007) </td>
   <td style="text-align:center;"> 0.666 (0.005) </td>
   <td style="text-align:center;"> -0.078 (0.000) </td>
   <td style="text-align:center;"> 1.550 (0.037) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> Generic </td>
   <td style="text-align:center;"> 17.231 (0.087) </td>
   <td style="text-align:center;"> 0.425 (0.000) </td>
   <td style="text-align:center;"> 0.158 (0.001) </td>
   <td style="text-align:center;"> 0.670 (0.006) </td>
   <td style="text-align:center;"> 0.667 (0.006) </td>
   <td style="text-align:center;"> -0.078 (0.000) </td>
   <td style="text-align:center;"> 1.551 (0.038) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> Greedy </td>
   <td style="text-align:center;"> 17.537 (0.021) </td>
   <td style="text-align:center;"> 0.425 (0.000) </td>
   <td style="text-align:center;"> 0.161 (0.002) </td>
   <td style="text-align:center;"> 0.689 (0.002) </td>
   <td style="text-align:center;"> 0.687 (0.002) </td>
   <td style="text-align:center;"> -0.078 (0.000) </td>
   <td style="text-align:center;"> 1.663 (0.019) </td>
  </tr>
  <tr>
   <td style="text-align:center;width: 5em; font-weight: bold;"> circles </td>
   <td style="text-align:center;"> 50 </td>
   <td style="text-align:center;"> LogisticRegression </td>
   <td style="text-align:center;"> REVISE </td>
   <td style="text-align:center;"> 17.536 (0.021) </td>
   <td style="text-align:center;"> 0.425 (0.000) </td>
   <td style="text-align:center;"> 0.204 (0.000) </td>
   <td style="text-align:center;"> 0.688 (0.004) </td>
   <td style="text-align:center;"> 0.685 (0.002) </td>
   <td style="text-align:center;"> -0.078 (0.000) </td>
   <td style="text-align:center;"> 1.653 (0.019) </td>
  </tr>
</tbody>
</table>

Chart in paper

Figure 7 shows the chart that went into the paper.

Code
# Build the synthetic-data chart used in the paper from the `:overlapping` results:
# for every (generator, model, n, name, scope) group, the mean of `value` together with
# a mean ± one-standard-deviation range, rendered with ggplot2 through RCall.
using DataFrames, Statistics
df = results[:overlapping].output
# Keep only the rows whose `n` equals 50.
df = df[[x ∈ [50] for x in df.n], :]
gdf = groupby(df, [:generator, :model, :n, :name, :scope])
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
df_plot = df_plot[[name in [:decisiveness, :disagreement, :mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
# Drop the rows where `:mmd` is reported with scope `:model`.
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
# Symbol columns are converted to strings before the table is handed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)

# Display labels for metrics, generators and models.
df_plot.name = replace(
    df_plot.name,
    "decisiveness" => "Decisiveness",
    "disagreement" => "Disagreement",
    "mmd" => "MMD (domain)",
    "mmd_grid" => "MMD (model)",
    "model_performance" => "Performance",
)
df_plot.generator = replace(df_plot.generator, "REVISE" => "Latent")
df_plot.model = replace(
    df_plot.model,
    "FluxEnsemble" => "Deep Ensemble",
    "FluxModel" => "MLP",
    "LogisticRegression" => "Linear",
)

# Facet grid dimensions, used below to scale the saved image.
# NOTE(review): these globals share their names with `ncol`/`nrow` exported by DataFrames.
ncol = length(unique(df_plot.model))
nrow = length(unique(df_plot.name))

using RCall
scale_ = 1.5
R"""
library(data.table)
df_plot <- data.table($df_plot)
name_order <- c(
    "MMD (domain)",
    "MMD (model)",
    "Performance",
    "Disagreement",
    "Decisiveness"
)
df_plot[,name:=factor(name, levels=name_order)]
model_order <- c("Linear", "MLP", "Deep Ensemble")
df_plot[,model:=factor(model, levels=model_order)]
library(ggplot2)
plt <- ggplot(df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange(aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(model), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    )
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.75) 
"""

# Copy the image written by R into the notebook's www directory.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_synthetic_results.png"), img)
Code
Images.load(joinpath(www_path,"paper_synthetic_results.png"))

Real-World Data

Code
using Pkg; Pkg.activate("dev")
Code
include("dev/utils.jl")
using AlgorithmicRecourseDynamics
using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra
theme(:wong)
output_path = output_dir("real_world")
www_path = www_dir("real_world");
Code
# `max_obs` is forwarded to the loader (presumably an upper bound on the number of
# observations per data set; confirm against `load_real_world`).
max_obs = 2500
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
# Keep only the entries whose first component is one of the chosen names.
choices = [:cal_housing, :credit_default, :gmsc]
data_sets = filter(entry -> first(entry) in choices, data_sets)
Code
using CounterfactualExplanations.DataPreprocessing: unpack

# Mini-batch size; read as a global by `data_loader`.
bs = 50

"""
    data_loader(data::CounterfactualData)

Unpack `data` into its `(X, y)` pair and return a `Flux.DataLoader` over that pair with
batch size `bs`.
"""
function data_loader(data::CounterfactualData)
    features, labels = unpack(data)
    return Flux.DataLoader((features, labels); batchsize=bs)
end

# Keyword settings passed on to `set_up_experiments` as `model_params`.
model_params = (
    batch_norm=false,
    n_hidden=32,
    n_layers=3,
    dropout=true,
    p_dropout=0.25,
)
Code
# Model types handed to `set_up_experiments`.
models = [:LogisticRegression, :FluxModel, :FluxEnsemble]

# Counterfactual generators, keyed by the name used for them in the results.
generators = Dict(
    :Greedy => GreedyGenerator(),
    :Generic => GenericGenerator(),
    :REVISE => REVISEGenerator(),
    :DICE => DiCEGenerator(),
)
Code
# Set up one experiment per data set, using the custom data loader and model settings.
experiments = set_up_experiments(
    data_sets,
    models,
    generators;
    pre_train_models=100,
    model_params=model_params,
    data_loader=data_loader,
)

Running Experiment

Code
# Experiment settings; `evaluate_every` spaces `n_evals` evaluations over `n_rounds`.
n_evals = 5
n_rounds = 50
evaluate_every = round(Int, n_rounds / n_evals)
n_folds = 5
n_bootstrap = 1
n_samples = 10000
T = 250
generative_model_params = (epochs=250, latent_dim=8)

using Serialization
results = run_experiments(
    experiments;
    save_path=output_path,
    evaluate_every=evaluate_every,
    n_rounds=n_rounds,
    n_folds=n_folds,
    n_bootstrap=n_bootstrap,
    T=T,
    n_samples=n_samples,
    generative_model_params=generative_model_params,
)
# Persist the results so that the plotting cells can reload them.
Serialization.serialize(joinpath(output_path, "results.jls"), results)

Plots

Code
using Serialization
# Reload the results serialized to `results.jls` in `output_path`.
results = Serialization.deserialize(joinpath(output_path,"results.jls"))
Code
using Images
# Plots are saved to `www_path` and also kept in memory, keyed by data set name.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Chart produced by `plot(res)`.
    line_file = joinpath(www_path, "line_chart_$(data_name).png")
    plt = plot(res)
    Images.save(line_file, plt)
    line_charts[data_name] = plt

    # Chart produced by `plot(res, n)` for the largest recorded `n`.
    final_n = maximum(res.output.n)
    errorbar_file = joinpath(www_path, "errorbar_chart_$(data_name).png")
    plt = plot(res, final_n)
    Images.save(errorbar_file, plt)
    errorbar_charts[data_name] = plt
end

Line Charts

Figure 8 shows the evolution of the evaluation metrics over the course of the experiment.

Code
# Display every file in `www_path` whose name contains "line_chart".
img_files = filter(contains("line_chart"), readdir(www_path))
img_files = joinpath.(www_path, img_files)
foreach(img -> display(load(img)), img_files)

(a) California Housing

(b) Credit Default

(c) GMSC

Error Bar Charts

Figure 9 shows the evaluation metrics at the end of the experiments.

Code
# Display every file in `www_path` whose name contains "errorbar_chart".
img_files = filter(contains("errorbar_chart"), readdir(www_path))
img_files = joinpath.(www_path, img_files)
foreach(img -> display(load(img)), img_files)

(a) California Housing

(b) Credit Default

(c) GMSC

Chart in paper

Figure 10 shows the chart that went into the paper.

Code
# Build the real-world chart used in the paper: for the model selected by `model_`, the
# per-generator mean of `value` with a mean ± one-standard-deviation range, one facet
# column per data set.
using DataFrames, Statistics
model_ = :FluxEnsemble

# Stack the output tables of all data sets, tagging each row with its data set key.
df = DataFrame() 
for (key, val) in results
    df_ = deepcopy(val.output)
    df_.dataset .= key
    df = vcat(df,df_)
end

# Keep the rows with the largest `n` and the selected model, and drop unusable values.
df = df[df.n .== maximum(df.n), :]
df = df[df.model .== model_, :]
filter!(:value => x -> !any(f -> f(x), (ismissing, isnothing, isnan)), df)

gdf = groupby(df, [:generator, :dataset, :n, :name, :scope])
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
df_plot = df_plot[in.(df_plot.name, Ref((:mmd, :model_performance))), :]
# Symbol columns are converted to strings before the table is handed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
# Make the MMD rows distinguishable by scope: "mmd" becomes "mmd_<scope>".
df_plot.name .= [r[:name] == "mmd" ? "$(r[:name])_$(r[:scope])" : r[:name] for r in eachrow(df_plot)]

# Display labels for data sets, metrics and generators.
df_plot.dataset = replace(
    df_plot.dataset,
    "cal_housing" => "California Housing",
    "credit_default" => "Credit Default",
    "gmsc" => "GMSC",
)
df_plot.name = replace(
    df_plot.name,
    "mmd_domain" => "MMD (domain)",
    "mmd_model" => "MMD (model)",
    "model_performance" => "Performance",
)
df_plot.generator = replace(df_plot.generator, "REVISE" => "Latent")

# Facet grid dimensions, used below to scale the saved image.
ncol = length(unique(df_plot.dataset))
nrow = length(unique(df_plot.name))

using RCall
scale_ = 1.75
R"""
library(ggplot2)
plt <- ggplot($df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(dataset), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    )
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8) 
"""

# Copy the image written by R into the notebook's www directory.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_real_world_results.png"), img)
Code
Images.load(joinpath(www_path,"paper_real_world_results.png"))

Mitigation Strategies

Code
using Pkg; Pkg.activate("dev")
Code
include("dev/utils.jl")
using AlgorithmicRecourseDynamics
using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra
theme(:wong)
output_path = output_dir("mitigation_strategies")
www_path = www_dir("mitigation_strategies")
Code
# Model types handed to `set_up_experiments`.
models = [:LogisticRegression, :FluxModel, :FluxEnsemble]

# Generators compared in the mitigation experiments, keyed by the name used for them in
# the results. The two generic variants differ only in `decision_threshold`.
generators = Dict(
    :Generic => GenericGenerator(; decision_threshold=0.5),
    :Latent => REVISEGenerator(),
    :Generic_conservative => GenericGenerator(; decision_threshold=0.9),
    :Gravitational => GravitationalGenerator(),
    :ClapROAR => ClapROARGenerator(),
)

Synthetic

Code
# `max_obs` is forwarded to the loader (presumably an upper bound on the number of
# observations per data set; confirm against `load_synthetic`).
max_obs = 1000
catalogue = AlgorithmicRecourseDynamics.Data.load_synthetic(max_obs)
# Keep only the catalogue entries whose first component is one of the chosen names.
choices = [:linearly_separable, :overlapping, :circles, :moons]
data_sets = filter(entry -> first(entry) in choices, catalogue)
Code
experiments = set_up_experiments(data_sets,models,generators)
Code
# Plot every model of every experiment against its test data before the experiments run,
# highlighting misclassified test points, and save the grid as one image.
using AlgorithmicRecourseDynamics.Models: model_evaluation
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.test_data); digits=2)
        plt = plot(M, exp_.test_data; title="$exp_name;\n $M_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:, ids]
        scatter!(plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :]; ms=7.5, color=:red, label="")
        push!(plts, plt)
    end
end
# The layout is (rows, columns) but `size` is (width, height): the width must scale with
# the number of columns (models) and the height with the number of rows (data sets).
plt = plot(
    plts...;
    layout=(length(choices), length(models)),
    size=(length(models) * 300, length(choices) * 300),
)
savefig(plt, joinpath(www_path,"models_test_before.png"))
Code
# Plot every model of every experiment against its training data before the experiments
# run, highlighting misclassified training points, and save the grid as one image.
using AlgorithmicRecourseDynamics.Models: model_evaluation
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.train_data); digits=2)
        plt = plot(M, exp_.train_data; title="$exp_name;\n $M_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.train_data.X)) .!= exp_.train_data.y))
        x_wrongly_labelled = exp_.train_data.X[:, ids]
        scatter!(plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :]; ms=7.5, color=:red, label="")
        push!(plts, plt)
    end
end
# The layout is (rows, columns) but `size` is (width, height): the width must scale with
# the number of columns (models) and the height with the number of rows (data sets).
plt = plot(
    plts...;
    layout=(length(choices), length(models)),
    size=(length(models) * 300, length(choices) * 300),
)
savefig(plt, joinpath(www_path,"models_train_before.png"))
Code
# Experiment settings; `evaluate_every` spaces `n_evals` evaluations over `n_rounds`.
n_evals = 5
n_rounds = 50
evaluate_every = round(Int, n_rounds / n_evals)
n_folds = 5
n_bootstrap = 1
T = 100

using Serialization
results = run_experiments(
    experiments;
    save_path=output_path,
    evaluate_every=evaluate_every,
    n_rounds=n_rounds,
    n_folds=n_folds,
    n_bootstrap=n_bootstrap,
    T=T,
)
# Persist the results so that the plotting cells can reload them.
Serialization.serialize(joinpath(output_path, "results_synthetic.jls"), results)
Code
# Plot the models held by the recourse systems of fold `fold` against the test data after
# the experiments have run, and save one image per generator.
using AlgorithmicRecourseDynamics.Models: model_evaluation
plot_dict = Dict(key => Dict() for key in keys(results))
fold = 1
for (name, res) in results
    exp_ = res.experiment
    plot_dict[name] = Dict(key => [] for key in keys(exp_.generators))
    rec_sys = exp_.recourse_systems[fold]
    sys_ids = collect(exp_.system_identifiers)
    # The count and the model get separate names; the original reused `M` for both.
    n_systems = length(rec_sys)
    for m in 1:n_systems
        model_name, generator_name = sys_ids[m]
        model = rec_sys[m].model
        score = round(model_evaluation(model, exp_.test_data); digits=2)
        plt = plot(model, exp_.test_data; title="$name;\n $model_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(model, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:, ids]
        scatter!(plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :]; ms=7.5, color=:red, label="")
        push!(plot_dict[name][generator_name], plt)
    end
end
# Regroup by generator: concatenate the per-data-set plot lists of each generator.
plot_dict = Dict(
    key => reduce(vcat, [plots[key] for plots in values(plot_dict)]) for
    key in keys(generators)
)
for (name, plts) in plot_dict
    # The layout is (rows, columns) but `size` is (width, height): the width must scale
    # with the number of columns (models) and the height with the rows (data sets).
    plt = plot(
        plts...;
        layout=(length(choices), length(models)),
        size=(length(models) * 300, length(choices) * 300),
    )
    savefig(plt, joinpath(www_path,"models_test_after_$(name).png"))
end
Code
# Plot the models held by the recourse systems of fold `fold` against each system's own
# `data` after the experiments have run, and save one image per generator.
using AlgorithmicRecourseDynamics.Models: model_evaluation
plot_dict = Dict(key => Dict() for key in keys(results))
fold = 1
for (name, res) in results
    exp_ = res.experiment
    plot_dict[name] = Dict(key => [] for key in keys(exp_.generators))
    rec_sys = exp_.recourse_systems[fold]
    sys_ids = collect(exp_.system_identifiers)
    # The count and the model get separate names; the original reused `M` for both.
    n_systems = length(rec_sys)
    for m in 1:n_systems
        model_name, generator_name = sys_ids[m]
        model = rec_sys[m].model
        data = rec_sys[m].data
        score = round(model_evaluation(model, data); digits=2)
        plt = plot(model, data; title="$name;\n $model_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(model, data.X)) .!= data.y))
        x_wrongly_labelled = data.X[:, ids]
        scatter!(plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :]; ms=7.5, color=:red, label="")
        push!(plot_dict[name][generator_name], plt)
    end
end
# Regroup by generator: concatenate the per-data-set plot lists of each generator.
plot_dict = Dict(
    key => reduce(vcat, [plots[key] for plots in values(plot_dict)]) for
    key in keys(generators)
)
for (name, plts) in plot_dict
    # The layout is (rows, columns) but `size` is (width, height): the width must scale
    # with the number of columns (models) and the height with the rows (data sets).
    plt = plot(
        plts...;
        layout=(length(choices), length(models)),
        size=(length(models) * 300, length(choices) * 300),
    )
    savefig(plt, joinpath(www_path,"models_train_after_$(name).png"))
end
Code
using Serialization
# Reload the results serialized to `results_synthetic.jls` in `output_path`.
results = Serialization.deserialize(joinpath(output_path,"results_synthetic.jls"))
Code
using Images
# Plots are saved to `www_path` and also kept in memory, keyed by data set name.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Chart produced by `plot(res)`.
    line_file = joinpath(www_path, "line_chart_$(data_name).png")
    plt = plot(res)
    Images.save(line_file, plt)
    line_charts[data_name] = plt

    # Chart produced by `plot(res, n)` for the largest recorded `n`.
    final_n = maximum(res.output.n)
    errorbar_file = joinpath(www_path, "errorbar_chart_$(data_name).png")
    plt = plot(res, final_n)
    Images.save(errorbar_file, plt)
    errorbar_charts[data_name] = plt
end

Chart in paper

Code
# Build the mitigation-strategies chart for the `:overlapping` data set: per-group mean
# of `value` with a mean ± one-standard-deviation range at the largest recorded `n`.
using DataFrames, Statistics
df = results[:overlapping].output
df = df[df.n .== maximum(df.n),:]
gdf = groupby(df, [:generator, :model, :n, :name, :scope])
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
df_plot = df_plot[[name in [:mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
# Drop the rows where `:mmd` is reported with scope `:model`.
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
# Symbol columns are converted to strings before the table is handed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)

# Display labels for metrics, generators and models.
df_plot.name = replace(
    df_plot.name,
    "mmd" => "MMD (domain)",
    "mmd_grid" => "MMD (model)",
    "model_performance" => "Performance",
)
df_plot.generator = replace(
    df_plot.generator,
    "Generic" => "Generic (γ=0.5)",
    "Generic_conservative" => "Generic (γ=0.9)",
)
df_plot.model = replace(
    df_plot.model,
    "FluxEnsemble" => "Deep Ensemble",
    "FluxModel" => "MLP",
    "LogisticRegression" => "Linear",
)

# Facet grid dimensions, used below to scale the saved image.
# NOTE(review): these globals share their names with `ncol`/`nrow` exported by DataFrames.
ncol = length(unique(df_plot.model))
nrow = length(unique(df_plot.name))

using RCall
scale_ = 2.0
# The plot must use the R-side `df_plot`, whose `model` column carries the factor order
# set below. Interpolating `$df_plot` again would pass the unordered Julia table.
R"""
library(data.table)
df_plot <- data.table($df_plot)
model_order <- c("Linear", "MLP", "Deep Ensemble")
df_plot[,model:=factor(model, levels=model_order)]
library(ggplot2)
plt <- ggplot(df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(model), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    ) +
    guides(fill=guide_legend(ncol=3))
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8) 
"""

# Copy the image written by R into the notebook's www directory.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_synthetic_results.png"), img)

Chart in paper

Code
# Build the latent-space variant of the mitigation chart for the `:overlapping` data set:
# per-group mean of `value` with a mean ± one-standard-deviation range at the largest
# recorded `n`.
using DataFrames, Statistics
df = results[:overlapping].output
df = df[df.n .== maximum(df.n),:]
gdf = groupby(df, [:generator, :model, :n, :name, :scope])
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
df_plot = df_plot[[name in [:mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
# Drop the rows where `:mmd` is reported with scope `:model`.
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
# Symbol columns are converted to strings before the table is handed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)

# Display labels for metrics, generators and models.
# NOTE(review): the `generators` dictionary defined in this section has no
# `:Latent_conservative` entry; presumably this cell expects results from a run with
# latent-space generators. Confirm which results file is meant to be loaded here.
df_plot.name = replace(
    df_plot.name,
    "mmd" => "MMD (domain)",
    "mmd_grid" => "MMD (model)",
    "model_performance" => "Performance",
)
df_plot.generator = replace(
    df_plot.generator,
    "Latent" => "Latent (γ=0.5)",
    "Latent_conservative" => "Latent (γ=0.9)",
)
df_plot.model = replace(
    df_plot.model,
    "FluxEnsemble" => "Deep Ensemble",
    "FluxModel" => "MLP",
    "LogisticRegression" => "Linear",
)

# Facet grid dimensions, used below to scale the saved image.
# NOTE(review): these globals share their names with `ncol`/`nrow` exported by DataFrames.
ncol = length(unique(df_plot.model))
nrow = length(unique(df_plot.name))

using RCall
scale_ = 1.9
# The plot must use the R-side `df_plot`, whose `model` column carries the factor order
# set below. Interpolating `$df_plot` again would pass the unordered Julia table.
R"""
library(data.table)
df_plot <- data.table($df_plot)
model_order <- c("Linear", "MLP", "Deep Ensemble")
df_plot[,model:=factor(model, levels=model_order)]
library(ggplot2)
plt <- ggplot(df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(model), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    ) +
    guides(fill=guide_legend(ncol=4))
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8) 
"""

# Copy the image written by R into the notebook's www directory.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_synthetic_latent_results.png"), img)

Real World

Code
# Generators compared in the real-world mitigation experiments, keyed by the name used
# for them in the results. The two generic variants differ only in `decision_threshold`.
generators = Dict(
    :Generic => GenericGenerator(; decision_threshold=0.5),
    :Latent => REVISEGenerator(),
    :Generic_conservative => GenericGenerator(; decision_threshold=0.9),
    :Gravitational => GravitationalGenerator(),
    :ClapROAR => ClapROARGenerator(),
)
Code
# `max_obs` is forwarded to the loader (presumably an upper bound on the number of
# observations per data set; confirm against `load_real_world`).
max_obs = 2500
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
Code
using CounterfactualExplanations.DataPreprocessing: unpack

# Mini-batch size; read as a global by `data_loader`.
bs = 50

"""
    data_loader(data::CounterfactualData)

Unpack `data` into its `(X, y)` pair and return a `Flux.DataLoader` over that pair with
batch size `bs`.
"""
function data_loader(data::CounterfactualData)
    features, labels = unpack(data)
    return Flux.DataLoader((features, labels); batchsize=bs)
end

# Keyword settings passed on to `set_up_experiments` as `model_params`.
model_params = (
    batch_norm=false,
    n_hidden=32,
    n_layers=3,
    dropout=true,
    p_dropout=0.25,
)
Code
# Set up one experiment per data set, using the custom data loader and model settings.
experiments = set_up_experiments(
    data_sets,
    models,
    generators;
    pre_train_models=100,
    model_params=model_params,
    data_loader=data_loader,
)
Code
# Experiment settings; `evaluate_every` spaces `n_evals` evaluations over `n_rounds`.
n_evals = 5
n_rounds = 50
evaluate_every = round(Int, n_rounds / n_evals)
n_folds = 5
n_bootstrap = 1
# NOTE(review): `n_samples` is assigned but not passed to `run_experiments` below, unlike
# the run in the "Real-World Data" section. Confirm whether that is intentional.
n_samples = 10000
T = 250

using Serialization
results = run_experiments(
    experiments;
    save_path=output_path,
    evaluate_every=evaluate_every,
    n_rounds=n_rounds,
    n_folds=n_folds,
    n_bootstrap=n_bootstrap,
    T=T,
)
# Persist the results so that the plotting cells can reload them.
Serialization.serialize(joinpath(output_path, "results_real_world.jls"), results)
Code
using Serialization
# Reload the results serialized to `results_real_world.jls` in `output_path`.
results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls"))
Code
using Images
# Plots are saved to `www_path` and also kept in memory, keyed by data set name.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Chart produced by `plot(res)`.
    line_file = joinpath(www_path, "line_chart_$(data_name).png")
    plt = plot(res)
    Images.save(line_file, plt)
    line_charts[data_name] = plt

    # Chart produced by `plot(res, n)` for the largest recorded `n`.
    final_n = maximum(res.output.n)
    errorbar_file = joinpath(www_path, "errorbar_chart_$(data_name).png")
    plt = plot(res, final_n)
    Images.save(errorbar_file, plt)
    errorbar_charts[data_name] = plt
end

Chart in paper

Code
# Build the real-world mitigation chart used in the paper: for the model selected by
# `model_`, the per-generator mean of `value` with a mean ± one-standard-deviation range,
# one facet column per data set.
using DataFrames, Statistics
model_ = :FluxEnsemble

# Stack the output tables of all data sets, tagging each row with its data set key.
df = DataFrame() 
for (key, val) in results
    df_ = deepcopy(val.output)
    df_.dataset .= key
    df = vcat(df,df_)
end

# Keep the rows with the largest `n` and the selected model, and drop unusable values.
df = df[df.n .== maximum(df.n), :]
df = df[df.model .== model_, :]
filter!(:value => x -> !any(f -> f(x), (ismissing, isnothing, isnan)), df)

gdf = groupby(df, [:generator, :dataset, :n, :name, :scope])
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
df_plot = df_plot[in.(df_plot.name, Ref((:mmd, :model_performance))), :]
# Of the `:mmd` rows, keep only those with scope `:model`.
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.!=:model),:]
# Symbol columns are converted to strings before the table is handed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)

# Display labels for data sets, metrics and generators.
df_plot.dataset = replace(
    df_plot.dataset,
    "cal_housing" => "California Housing",
    "credit_default" => "Credit Default",
    "gmsc" => "GMSC",
)
df_plot.name = replace(
    df_plot.name,
    "mmd" => "MMD (model)",
    "model_performance" => "Performance",
)
df_plot.generator = replace(
    df_plot.generator,
    "Generic" => "Generic (γ=0.5)",
    "Generic_conservative" => "Generic (γ=0.9)",
)

# Facet grid dimensions, used below to scale the saved image.
ncol = length(unique(df_plot.dataset))
nrow = length(unique(df_plot.name))

using RCall
scale_ = 2.0
R"""
library(ggplot2)
plt <- ggplot($df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(dataset), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    ) +
    guides(fill=guide_legend(ncol=3))
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.85) 
"""

# Copy the image written by R into the notebook's www directory.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_real_world_results.png"), img)

Generators

Gravitational Generator

Code
using Pkg; Pkg.activate("dev")
Code
include("dev/utils.jl")
using AlgorithmicRecourseDynamics
using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra
theme(:wong)
output_path = output_dir("generator")
www_path = www_dir("generator")

GravitationalGenerator

Code
# Generate a two-feature, two-centre blob data set and wrap it for counterfactual search.
using MLJ
N = 1000
X, ys = make_blobs(N, 2; centers=2, as_table=false, center_box=(-5 => 5), cluster_std=0.5)
# Overwrite the labels in place with the result of comparing each one to 2.
ys .= ys.==2
# Work with the transposed feature matrix (presumably features × observations; confirm).
X = X'
# Split `X` along its second dimension and pair each slice with its label.
xs = Flux.unstack(X,2)
data = zip(xs,ys)
counterfactual_data = CounterfactualData(X,ys')
Code
# Fit a single dense layer (2 inputs, 1 output) to `data` with ADAM, one update per
# observation, printing the average loss ten times over the run.
using Flux
nn = Chain(Dense(2,1))
using Flux.Optimise: update!, ADAM
opt = ADAM()
epochs = 100
loss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y)
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
show_every = epochs/10
for epoch in 1:epochs
    for d in data
        # Gradient of the loss on a single (x, y) pair with respect to the parameters.
        gs = gradient(Flux.params(nn)) do
            loss(d...)
        end
        update!(opt, Flux.params(nn), gs)
    end
    if iszero(epoch % show_every)
        println("Epoch ", epoch)
        @show avg_loss(data)
    end
end
Code
# Wrap the trained network in a FluxModel; `M` is the model handed to `probs` and
# `generate_counterfactual` in the following cells.
M = FluxModel(nn)
Code
# Pick the factual at a random index along the second dimension of X.
x = select_factual(counterfactual_data, rand(1:size(X, 2)))
# Rounded model probability serves as the predicted label of the factual.
y = round(first(probs(M, x)))
# The counterfactual target is the opposite label.
target = y == 1.0 ? 0.0 : 1.0
# Passed as `T` to the counterfactual search and used to cap the plotted path length.
T = 100
Code
# Counterfactual search with `convergence=:strict` for several values of λ₂.
# The generator receives λ = [0.1, λ₂]: the first entry is fixed at 0.1 while the
# second one is varied over Λ₂.
Λ₂ = [0.1, 1, 5]
counterfactuals_strict = []
generators = []
for λ₂ ∈ Λ₂
    λ = [0.1, λ₂]
    generator = GravitationalGenerator(λ=λ)
    # Keep the generator too: its `centroid` field is read when plotting the results.
    # `push!` mutates the collections instead of rebinding them through `vcat`, so the
    # loop does not depend on notebook-style soft scope.
    push!(generators, generator)
    push!(
        counterfactuals_strict,
        generate_counterfactual(
            x, target, counterfactual_data, M, generator; convergence=:strict, T=T
        ),
    )
end
Code
# One panel per value of λ₂: the counterfactual path (at most min(T, T_) steps) with
# the generator's centroid overlaid as the "attractor".
theme(:wong)
T_ = 500
plts = []
for i ∈ eachindex(Λ₂)
    λ₂ = Λ₂[i]
    counterfactual = counterfactuals_strict[i]
    plt = plot(counterfactual, plot_up_to=minimum([T,T_]), title="λ₂=$(λ₂)")
    x1, x2 = generators[i].centroid[1], generators[i].centroid[2]
    scatter!(plt, [x1], [x2], colour=:purple, label="attractor")
    push!(plts, plt)
end
# A single row with as many columns as there are panels.
plt = plot(plts..., size=(1200,300), layout=(1,length(plts)))
savefig(plt, joinpath(www_path,"gravitational_generator_strict.png"))

Code
# Same sweep over λ₂ as in the strict run, but `generate_counterfactual` is called
# without the `convergence` and `T` keywords.
# The generator receives λ = [0.1, λ₂]: the first entry is fixed at 0.1 while the
# second one is varied over Λ₂.
Λ₂ = [0.1, 1, 5]
counterfactuals = []
generators = []
for λ₂ ∈ Λ₂
    λ = [0.1, λ₂]
    generator = GravitationalGenerator(λ=λ)
    # Keep the generator too: its `centroid` field is read when plotting the results.
    # `push!` mutates the collections instead of rebinding them through `vcat`, so the
    # loop does not depend on notebook-style soft scope.
    push!(generators, generator)
    push!(
        counterfactuals,
        generate_counterfactual(x, target, counterfactual_data, M, generator),
    )
end
Code
# One panel per value of λ₂: the counterfactual path (at most min(T, T_) steps) with
# the generator's centroid overlaid as the "attractor".
theme(:wong)
T_ = 500
plts = []
for i ∈ eachindex(Λ₂)
    λ₂ = Λ₂[i]
    counterfactual = counterfactuals[i]
    plt = plot(counterfactual, plot_up_to=minimum([T,T_]), title="λ₂=$(λ₂)")
    x1, x2 = generators[i].centroid[1], generators[i].centroid[2]
    scatter!(plt, [x1], [x2], colour=:purple, label="attractor")
    push!(plts, plt)
end
# A single row with as many columns as there are panels.
plt = plot(plts..., size=(1400,300), layout=(1,length(plts)))
savefig(plt, joinpath(www_path,"gravitational_generator.png"))

Comparison - simple vs. strict convergence

Code
# Side-by-side view of the first sweep entry: simple run on the left, strict run on
# the right, each with the generator's centroid marked as the "attractor".
idx = 1
x1 = generators[idx].centroid[1]
x2 = generators[idx].centroid[2]

plt1 = plot(counterfactuals[idx]; plot_up_to=min(T, T_), title="Simple")
scatter!(plt1, [x1], [x2]; colour=:purple, label="attractor")

plt2 = plot(counterfactuals_strict[idx]; plot_up_to=min(T, T_), title="Strict")
scatter!(plt2, [x1], [x2]; colour=:purple, label="attractor")

plt = plot(plt1, plt2; layout=(1, 2), size=(850, 350))
savefig(plt, joinpath(www_path, "gravitational_generator_comparison.png"))

References

ClapROAR Generator

Code
using Pkg; Pkg.activate("dev")
Code
# Load the project's helper script.
# NOTE(review): presumably this is where `output_dir` and `www_dir` come from — confirm
# in dev/utils.jl.
include("dev/utils.jl")
using AlgorithmicRecourseDynamics
using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra
# Select the :wong plotting theme.
theme(:wong)
# Directories for the "generator" section. In the cells that follow, `www_path` is
# referenced only by a commented-out `savefig` call and `output_path` is not used.
output_path = output_dir("generator")
www_path = www_dir("generator")

ClapROARGenerator

Code
using MLJ

# Sample N points from two blobs in two dimensions, returned as plain arrays.
N = 1000
X, ys = make_blobs(
    N, 2; centers=2, cluster_std=0.1, center_box=(-5 => 5), as_table=false
)
# Overwrite the labels in place with the indicator `ys == 2`.
@. ys = ys == 2
# Work with the adjoint of X from here on; the factual index drawn later ranges over
# its second dimension.
X = adjoint(X)
# Training pairs: slices of X along dimension 2, matched with their labels.
xs = Flux.unstack(X, 2)
data = zip(xs, ys)
counterfactual_data = CounterfactualData(X, adjoint(ys))
Code
using Flux
# A single Dense(2,1) layer wrapped in a Chain.
nn = Chain(Dense(2,1))
using Flux.Optimise: update!, ADAM
opt = ADAM()
epochs = 100
# Loss on the raw network output (logit form of binary cross-entropy).
loss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y)
# Mean per-pair loss over an iterable of (x, y) pairs.
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
# Progress is printed whenever `epoch % show_every == 0`; note that `epochs/10` is a
# floating-point value (10.0 here).
show_every = epochs/10
# Receives one entry per epoch: the norm of the gradients collected in that epoch.
# NOTE(review): `grad_norms` is reset to `[]` again in the later cell that walks the
# counterfactual path, so these per-epoch values are not kept — confirm that is intended.
grad_norms = []
using LinearAlgebra
for epoch = 1:epochs
    # Gradient objects of every update step taken during this epoch.
    grads_ = []
    for d in data
        # Gradient of the loss on the single pair `d` w.r.t. the network parameters.
        gs = gradient(Flux.params(nn)) do
            l = loss(d...)
        end
        update!(opt, Flux.params(nn), gs)
        grads_ = vcat(grads_..., gs)
    end
    if epoch % show_every == 0
        println("Epoch " * string(epoch))
        @show avg_loss(data)
    end
    grad_norms = vcat(grad_norms..., norm(grads_))
end
Code
# Wrap the trained network in a FluxModel; `M` is the model handed to `probs` and
# `generate_counterfactual` in the following cells.
M = FluxModel(nn)
Code
# Pick the factual at a random index along the second dimension of X.
x = select_factual(counterfactual_data, rand(1:size(X, 2)))
# Rounded model probability serves as the predicted label of the factual.
y = round(first(probs(M, x)))
# The counterfactual target is the opposite label.
target = y == 1.0 ? 0.0 : 1.0
# Passed as `T` to the counterfactual search and used to cap the plotted path length.
T = 100
Code
# Counterfactual for `x` towards `target`, produced by a GenericGenerator built with
# its default arguments; `ce` is inspected step by step in the next cell.
generator = GenericGenerator()
ce = generate_counterfactual(x, target,counterfactual_data, M, generator)
Code
# Walk along the counterfactual path of `ce` and, at every step, record the model's
# loss and the norm of the loss gradient w.r.t. the network parameters.
x_ = path(ce)
y_ = CounterfactualExplanations.Counterfactuals.counterfactual_label_path(ce)
data_ = zip(x_, y_)
grad_norms = []
loss_ = []
for (x_t, y_t) in data_
    gs = gradient(() -> loss(x_t, y_t), Flux.params(nn))
    loss_ = vcat(loss_..., loss(x_t, y_t))
    grad_norms = vcat(grad_norms..., norm(gs))
end
Code
# Counterfactual search with the ClapROAR generator for several values of λ₂.
# The generator receives λ = [0.1, λ₂]: the first entry is fixed at 0.1 while the
# second one is varied over Λ₂.
# NOTE(review): the results are stored in `counterfactuals_strict`, yet no
# `convergence=:strict` keyword is passed here — confirm which behaviour is intended.
Λ₂ = [0.1, 1, 5]
counterfactuals_strict = []
generators = []
for λ₂ ∈ Λ₂
    λ = [0.1, λ₂]
    generator = ClapROARGenerator(λ=λ)
    # `push!` mutates the collections instead of rebinding them through `vcat`, so the
    # loop does not depend on notebook-style soft scope.
    push!(generators, generator)
    push!(
        counterfactuals_strict,
        generate_counterfactual(x, target, counterfactual_data, M, generator; T=T),
    )
end
Code
# One panel per value of λ₂ showing the counterfactual path (at most min(T, T_) steps).
theme(:wong)
T_ = 500
plts = []
for i ∈ eachindex(Λ₂)
    λ₂ = Λ₂[i]
    counterfactual = counterfactuals_strict[i]
    plt = plot(counterfactual, plot_up_to=minimum([T,T_]), title="λ₂=$(λ₂)")
    push!(plts, plt)
end
# A single row with as many columns as there are panels.
plt = plot(plts..., size=(1200,300), layout=(1,length(plts)))
# savefig(plt, joinpath(www_path,"endo_roar_generator.png"))
Code
# Generators to compare, keyed by the name used as the panel title.
generators = Dict(
    :Generic => GenericGenerator(),
    :ROAR => ClapROARGenerator(),
)
# One counterfactual per generator, all for the same factual, target and model.
counterfactuals = Dict(
    name => generate_counterfactual(x, target, counterfactual_data, M, gen; T=T)
    for (name, gen) in generators
)
Code
# One panel per generator, titled with its key in `counterfactuals` and showing at
# most min(T, T_) steps of the counterfactual path.
plts = []
for (name, ce) ∈ counterfactuals
    plt = plot(ce, plot_up_to=minimum([T,T_]), title=name)
    push!(plts, plt)
end
# A single row with as many columns as there are panels.
plt = plot(plts..., size=(800,300), layout=(1,length(plts)))